{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ur8xi4C7S06n"
      },
      "outputs": [],
      "source": [
        "# Copyright 2025 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JAPoU8Sm5E6e"
      },
      "source": [
        "# Intro to Model Context Protocol (MCP) integration with Vertex AI\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/mcp/intro_to_mcp.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fmcp%2Fintro_to_mcp.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/mcp/intro_to_mcp.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  \n",
        "  \n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/mcp/intro_to_mcp.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84f0f73a0f76"
      },
      "source": [
        "| Author |\n",
        "| --- |\n",
        "| [Dave Wang](https://github.com/wadave) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvgnzT1CKxrO"
      },
      "source": [
        "## Overview\n",
        "The Model Context Protocol (MCP) is an open standard that streamlines the integration of AI assistants with external data sources, tools, and systems. [MCP standardizes how applications provide context to LLMs](https://modelcontextprotocol.io/introduction): it defines a common interface through which AI models can connect to diverse external services.\n",
        "\n",
        "When building applications, developers can use third-party MCP servers or create their own.\n",
        "\n",
        "This notebook shows two ways to use MCP with Vertex AI:\n",
        "- Build a custom MCP server and use it with Gemini on Vertex AI.\n",
        "- Use a pre-built MCP server with Vertex AI."
      ]
    },
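    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c1a9e0d27f41"
      },
      "source": [
        "MCP messages are JSON-RPC 2.0 objects. As a minimal sketch, assuming the `tools/call` method and the `name`/`arguments` parameters described in the MCP specification, the cell below builds the request a client sends to ask a server to run a tool. `make_tool_call_request` is a hypothetical helper, not part of the `mcp` package; later in this notebook, `ClientSession.call_tool` sends this kind of message for you."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c1a9e0d27f42"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "from typing import Any\n",
        "\n",
        "\n",
        "def make_tool_call_request(request_id: int, tool_name: str, arguments: dict[str, Any]) -> str:\n",
        "    \"\"\"Build a JSON-RPC 2.0 request asking an MCP server to run one tool.\n",
        "\n",
        "    Hypothetical helper for illustration; the mcp SDK's ClientSession.call_tool\n",
        "    sends an equivalent message over the chosen transport.\n",
        "    \"\"\"\n",
        "    message = {\n",
        "        \"jsonrpc\": \"2.0\",  # every MCP message declares JSON-RPC 2.0\n",
        "        \"id\": request_id,  # lets the client match the server's response\n",
        "        \"method\": \"tools/call\",\n",
        "        \"params\": {\"name\": tool_name, \"arguments\": arguments},\n",
        "    }\n",
        "    return json.dumps(message)\n",
        "\n",
        "\n",
        "example_request = make_tool_call_request(1, \"get_alerts\", {\"state\": \"CA\"})\n",
        "print(example_request)"
      ]
    },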
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61RBz8LLbxCR"
      },
      "source": [
        "## Get started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "No17Cw5hgx12"
      },
      "source": [
        "### Install Google Gen AI SDK and other required packages\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFy3H3aPgx12"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --quiet google-genai mcp geopy uv"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dmWOrTJ3gx13"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you're running this notebook on Google Colab, run the cell below to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NyKGtVQjgx13"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DF4l8DTdWgPY"
      },
      "source": [
        "### Set Google Cloud project information\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Nqwi-5ufWp_B"
      },
      "outputs": [],
      "source": [
        "# Use the environment variable if the user doesn't provide Project ID.\n",
        "import os\n",
        "\n",
        "from google import genai\n",
        "\n",
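        "# TODO: set your own project ID\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n",
        "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
        "    # Avoid str(None): fail early if the environment variable is also unset.\n",
        "    PROJECT_ID = os.environ.get(\"GOOGLE_CLOUD_PROJECT\", \"\")\n",
        "    if not PROJECT_ID:\n",
        "        raise ValueError(\"Set PROJECT_ID or the GOOGLE_CLOUD_PROJECT environment variable.\")\n",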
        "\n",
        "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
        "\n",
        "client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)"
      ]
    },
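    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5e7b3c91d0a2"
      },
      "source": [
        "For reference, the fallback order used above can be written as a small pure function. `resolve_project_id` and `PLACEHOLDER` are hypothetical names used only for this sketch: it prefers the form value, falls back to the `GOOGLE_CLOUD_PROJECT` environment variable, and raises a clear error when neither is set rather than passing an unusable project ID to `genai.Client`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5e7b3c91d0a3"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from collections.abc import Mapping\n",
        "\n",
        "PLACEHOLDER = \"[your-project-id]\"  # default value of the PROJECT_ID form field\n",
        "\n",
        "\n",
        "def resolve_project_id(param_value: str, env: Mapping[str, str] = os.environ) -> str:\n",
        "    \"\"\"Return the form value if set, else GOOGLE_CLOUD_PROJECT; raise if neither is set.\"\"\"\n",
        "    if param_value and param_value != PLACEHOLDER:\n",
        "        return param_value\n",
        "    project_id = env.get(\"GOOGLE_CLOUD_PROJECT\", \"\")\n",
        "    if not project_id:\n",
        "        raise ValueError(\"Set PROJECT_ID or the GOOGLE_CLOUD_PROJECT environment variable.\")\n",
        "    return project_id\n",
        "\n",
        "\n",
        "print(resolve_project_id(PLACEHOLDER, {\"GOOGLE_CLOUD_PROJECT\": \"example-project\"}))"
      ]
    },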
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5303c05f7aa6"
      },
      "source": [
        "### Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "6fc324893334"
      },
      "outputs": [],
      "source": [
        "from typing import Any\n",
        "\n",
        "from google import genai\n",
        "from google.genai import types\n",
        "from mcp import ClientSession, StdioServerParameters\n",
        "from mcp.client.stdio import stdio_client"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e43229f3ad4f"
      },
      "source": [
        "### Load model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "cf93d5f0ce00"
      },
      "outputs": [],
      "source": [
        "MODEL_ID = \"gemini-2.0-flash-001\"  # @param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1e7b40e87f22"
      },
      "source": [
        "### Create an MCP weather server\n",
        "The [Server development guide](https://modelcontextprotocol.io/quickstart/server) explains in detail how to create an MCP server.\n",
        "\n",
        "Here we modify the guide's sample server to include three tools:\n",
        "\n",
        "- Get weather alerts by US state (`get_alerts`)\n",
        "- Get forecast by coordinates (`get_forecast`)\n",
        "- Get forecast by city name (`get_forecast_by_city`)"
      ]
    },
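    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9b0d4e6f2a11"
      },
      "source": [
        "All three tools end in the same two-step National Weather Service (NWS) lookup (`get_forecast_by_city` first geocodes the city into coordinates): `/points/{lat},{lon}` returns gridpoint metadata whose `properties.forecast` field holds the forecast URL, and that forecast's `properties.periods` list holds the periods to display. The cell below is a minimal offline sketch of that chain. `FAKE_RESPONSES`, its example URL and values, and `lookup_forecast` are hypothetical stand-ins for real API responses; the server code that follows performs the same steps with live `httpx` requests."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9b0d4e6f2a12"
      },
      "outputs": [],
      "source": [
        "from typing import Any, Callable, Optional\n",
        "\n",
        "# Canned, made-up responses standing in for api.weather.gov (hypothetical data).\n",
        "FORECAST_URL = \"https://api.weather.gov/gridpoints/LOX/155,45/forecast\"\n",
        "FAKE_RESPONSES: dict[str, dict[str, Any]] = {\n",
        "    \"/points/34.0500,-118.2500\": {\"properties\": {\"forecast\": FORECAST_URL}},\n",
        "    FORECAST_URL: {\n",
        "        \"properties\": {\n",
        "            \"periods\": [\n",
        "                {\"name\": \"Tonight\", \"temperature\": 58, \"temperatureUnit\": \"F\"},\n",
        "                {\"name\": \"Tuesday\", \"temperature\": 75, \"temperatureUnit\": \"F\"},\n",
        "            ]\n",
        "        }\n",
        "    },\n",
        "}\n",
        "\n",
        "\n",
        "def lookup_forecast(\n",
        "    latitude: float,\n",
        "    longitude: float,\n",
        "    fetch: Callable[[str], Optional[dict[str, Any]]],\n",
        "    max_periods: int = 5,\n",
        ") -> list[str]:\n",
        "    \"\"\"Follow the NWS points -> forecast chain and summarize the first periods.\"\"\"\n",
        "    # Step 1: look up the gridpoint for the coordinates (4 decimal places).\n",
        "    points = fetch(f\"/points/{latitude:.4f},{longitude:.4f}\")\n",
        "    forecast_url = (points or {}).get(\"properties\", {}).get(\"forecast\")\n",
        "    if not forecast_url:\n",
        "        return []\n",
        "    # Step 2: follow the gridpoint's link to the forecast for that grid cell.\n",
        "    forecast = fetch(forecast_url)\n",
        "    periods = (forecast or {}).get(\"properties\", {}).get(\"periods\") or []\n",
        "    return [\n",
        "        f\"{p['name']}: {p['temperature']}°{p['temperatureUnit']}\"\n",
        "        for p in periods[:max_periods]\n",
        "    ]\n",
        "\n",
        "\n",
        "# A fake fetcher: dict.get returns None for unknown URLs, like a failed request.\n",
        "print(lookup_forecast(34.05, -118.25, FAKE_RESPONSES.get))"
      ]
    },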
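    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3f1c9a7e2b05"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# The %%writefile magic below does not create missing directories, so create server/ first.\n",
        "os.makedirs(\"server\", exist_ok=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "199464a1c1d4"
      },
      "outputs": [],
      "source": [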
        "%%writefile server/weather_server.py\n",
        "import json\n",
        "from typing import Any, Dict, Optional\n",
        "import httpx\n",
        "from mcp.server.fastmcp import FastMCP\n",
        "from geopy.geocoders import Nominatim\n",
        "from geopy.exc import GeocoderTimedOut, GeocoderServiceError\n",
        "\n",
        "# Initialize FastMCP server\n",
        "mcp = FastMCP(\"weather\")\n",
        "\n",
        "# --- Configuration & Constants ---\n",
        "BASE_URL = \"https://api.weather.gov\"\n",
        "USER_AGENT = \"weather-agent\"\n",
        "REQUEST_TIMEOUT = 20.0\n",
        "GEOCODE_TIMEOUT = 10.0  # Timeout for geocoding requests\n",
        "\n",
        "# --- Shared HTTP Client ---\n",
        "http_client = httpx.AsyncClient(\n",
        "    base_url=BASE_URL,\n",
        "    headers={\"User-Agent\": USER_AGENT, \"Accept\": \"application/geo+json\"},\n",
        "    timeout=REQUEST_TIMEOUT,\n",
        "    follow_redirects=True,\n",
        ")\n",
        "\n",
        "# --- Geocoding Setup ---\n",
        "# Initialize the geocoder (Nominatim requires a unique user_agent)\n",
        "geolocator = Nominatim(user_agent=USER_AGENT)\n",
        "\n",
        "\n",
        "async def get_weather_response(endpoint: str) -> Optional[Dict[str, Any]]:\n",
        "    \"\"\"\n",
        "    Make a request to the NWS API using the shared client with error handling.\n",
        "    Returns None if an error occurs.\n",
        "    \"\"\"\n",
        "    try:\n",
        "        response = await http_client.get(endpoint)\n",
        "        response.raise_for_status()  # Raises HTTPStatusError for 4xx/5xx responses\n",
        "        return response.json()\n",
        "    except httpx.HTTPStatusError:\n",
        "        # Specific HTTP errors (like 404 Not Found, 500 Server Error)\n",
        "        return None\n",
        "    except httpx.TimeoutException:\n",
        "        # Request timed out\n",
        "        return None\n",
        "    except httpx.RequestError:\n",
        "        # Other request errors (connection, DNS, etc.)\n",
        "        return None\n",
        "    except json.JSONDecodeError:\n",
        "        # Response was not valid JSON\n",
        "        return None\n",
        "    except Exception:\n",
        "        # Any other unexpected errors\n",
        "        return None\n",
        "\n",
        "\n",
        "def format_alert(feature: Dict[str, Any]) -> str:\n",
        "    \"\"\"Format an alert feature into a readable string.\"\"\"\n",
        "    props = feature.get(\"properties\", {})  # Safer access\n",
        "    # Use .get() with default values for robustness\n",
        "    return f\"\"\"\n",
        "            Event: {props.get('event', 'Unknown Event')}\n",
        "            Area: {props.get('areaDesc', 'N/A')}\n",
        "            Severity: {props.get('severity', 'N/A')}\n",
        "            Certainty: {props.get('certainty', 'N/A')}\n",
        "            Urgency: {props.get('urgency', 'N/A')}\n",
        "            Effective: {props.get('effective', 'N/A')}\n",
        "            Expires: {props.get('expires', 'N/A')}\n",
        "            Description: {props.get('description', 'No description provided.').strip()}\n",
        "            Instructions: {props.get('instruction', 'No instructions provided.').strip()}\n",
        "            \"\"\"\n",
        "\n",
        "\n",
        "def format_forecast_period(period: Dict[str, Any]) -> str:\n",
        "    \"\"\"Formats a single forecast period into a readable string.\"\"\"\n",
        "    return f\"\"\"\n",
        "           {period.get('name', 'Unknown Period')}:\n",
        "             Temperature: {period.get('temperature', 'N/A')}°{period.get('temperatureUnit', 'F')}\n",
        "             Wind: {period.get('windSpeed', 'N/A')} {period.get('windDirection', 'N/A')}\n",
        "             Short Forecast: {period.get('shortForecast', 'N/A')}\n",
        "             Detailed Forecast: {period.get('detailedForecast', 'No detailed forecast provided.').strip()}\n",
        "           \"\"\"\n",
        "\n",
        "\n",
        "# --- MCP Tools ---\n",
        "\n",
        "@mcp.tool()\n",
        "async def get_alerts(state: str) -> str:\n",
        "    \"\"\"\n",
        "    Get active weather alerts for a specific US state.\n",
        "\n",
        "    Args:\n",
        "        state: The two-letter US state code (e.g., CA, NY, TX). Case-insensitive.\n",
        "    \"\"\"\n",
        "    # Input validation and normalization\n",
        "    if not isinstance(state, str) or len(state) != 2 or not state.isalpha():\n",
        "        return \"Invalid input. Please provide a two-letter US state code (e.g., CA).\"\n",
        "    state_code = state.upper()\n",
        "\n",
        "    endpoint = f\"/alerts/active/area/{state_code}\"\n",
        "    data = await get_weather_response(endpoint)\n",
        "\n",
        "    if data is None:\n",
        "        # Error occurred during request\n",
        "        return f\"Failed to retrieve weather alerts for {state_code}.\"\n",
        "\n",
        "    features = data.get(\"features\")\n",
        "    if not features:  # Handles both null and empty list\n",
        "        return f\"No active weather alerts found for {state_code}.\"\n",
        "\n",
        "    alerts = [format_alert(feature) for feature in features]\n",
        "    return \"\\n---\\n\".join(alerts)\n",
        "\n",
        "\n",
        "@mcp.tool()\n",
        "async def get_forecast(latitude: float, longitude: float) -> str:\n",
        "    \"\"\"\n",
        "    Get the weather forecast for a specific location using latitude and longitude.\n",
        "\n",
        "    Args:\n",
        "        latitude: The latitude of the location (e.g., 34.05).\n",
        "        longitude: The longitude of the location (e.g., -118.25).\n",
        "    \"\"\"\n",
        "    # Input validation\n",
        "    if not (-90 <= latitude <= 90 and -180 <= longitude <= 180):\n",
        "        return \"Invalid latitude or longitude provided. Latitude must be between -90 and 90, Longitude between -180 and 180.\"\n",
        "\n",
        "    # NWS API requires latitude,longitude format with up to 4 decimal places\n",
        "    point_endpoint = f\"/points/{latitude:.4f},{longitude:.4f}\"\n",
        "    points_data = await get_weather_response(point_endpoint)\n",
        "\n",
        "    if points_data is None or \"properties\" not in points_data:\n",
        "        return f\"Unable to retrieve NWS gridpoint information for {latitude:.4f},{longitude:.4f}.\"\n",
        "\n",
        "    # Extract the forecast URL from the gridpoint data\n",
        "    forecast_url = points_data[\"properties\"].get(\"forecast\")\n",
        "\n",
        "    if not forecast_url:\n",
        "        return f\"Could not find the NWS forecast endpoint for {latitude:.4f},{longitude:.4f}.\"\n",
        "\n",
        "    # Request the forecast URL. It is absolute, so httpx uses it as-is instead of\n",
        "    # joining it to base_url; get_weather_response returns None on any error.\n",
        "    forecast_data = await get_weather_response(forecast_url)\n",
        "\n",
        "    if forecast_data is None or \"properties\" not in forecast_data:\n",
        "        return \"Failed to retrieve detailed forecast data from NWS.\"\n",
        "\n",
        "    periods = forecast_data[\"properties\"].get(\"periods\")\n",
        "    if not periods:\n",
        "        return \"No forecast periods found for this location from NWS.\"\n",
        "\n",
        "    # Format the first 5 periods\n",
        "    forecasts = [format_forecast_period(period) for period in periods[:5]]\n",
        "\n",
        "    return \"\\n---\\n\".join(forecasts)\n",
        "\n",
        "# --- get_forecast_by_city Tool ---\n",
        "@mcp.tool()\n",
        "async def get_forecast_by_city(city: str, state: str) -> str:\n",
        "    \"\"\"\n",
        "    Get the weather forecast for a specific US city and state by first finding its coordinates.\n",
        "\n",
        "    Args:\n",
        "        city: The name of the city (e.g., \"Los Angeles\", \"New York\").\n",
        "        state: The two-letter US state code (e.g., CA, NY). Case-insensitive.\n",
        "    \"\"\"\n",
        "    # --- Input Validation ---\n",
        "    if not city or not isinstance(city, str):\n",
        "        return \"Invalid city name provided.\"\n",
        "    if (\n",
        "        not state\n",
        "        or not isinstance(state, str)\n",
        "        or len(state) != 2\n",
        "        or not state.isalpha()\n",
        "    ):\n",
        "        return \"Invalid state code. Please provide the two-letter US state abbreviation (e.g., CA).\"\n",
        "\n",
        "    city_name = city.strip()\n",
        "    state_code = state.strip().upper()\n",
        "    # Construct a query likely to yield a US result\n",
        "    query = f\"{city_name}, {state_code}, USA\"\n",
        "\n",
        "    # --- Geocoding ---\n",
        "    location = None\n",
        "    try:\n",
        "        # Synchronous geocode call (blocks the event loop until Nominatim responds)\n",
        "        location = geolocator.geocode(query, timeout=GEOCODE_TIMEOUT)\n",
        "\n",
        "    except GeocoderTimedOut:\n",
        "        return f\"Could not get coordinates for '{city_name}, {state_code}': The location service timed out.\"\n",
        "    except GeocoderServiceError:\n",
        "        return f\"Could not get coordinates for '{city_name}, {state_code}': The location service returned an error.\"\n",
        "    except Exception:\n",
        "        # Catch any other unexpected errors during geocoding\n",
        "        return f\"An unexpected error occurred while finding coordinates for '{city_name}, {state_code}'.\"\n",
        "\n",
        "    # --- Handle Geocoding Result ---\n",
        "    if location is None:\n",
        "        return f\"Could not find coordinates for '{city_name}, {state_code}'. Please check the spelling or try a nearby city.\"\n",
        "\n",
        "    latitude = location.latitude\n",
        "    longitude = location.longitude\n",
        "\n",
        "    # --- Reuse existing forecast logic with obtained coordinates ---\n",
        "    return await get_forecast(latitude, longitude)\n",
        "\n",
        "\n",
        "# --- Server Execution & Shutdown ---\n",
        "async def shutdown_event():\n",
        "    \"\"\"Gracefully close the shared httpx client.\n",
        "\n",
        "    Nothing in this script calls this helper; invoke it from your own\n",
        "    shutdown logic if you need to close the client explicitly.\n",
        "    \"\"\"\n",
        "    await http_client.aclose()\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    mcp.run(transport=\"stdio\")\n"
      ]
    },
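    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e2f7a8c4b613"
      },
      "source": [
        "`FastMCP` derives each tool's name, description, and JSON Schema `inputSchema` from the decorated function's name, docstring, and type hints; the agent loop below forwards that schema to Gemini as the function declaration's `parameters`. FastMCP's actual implementation uses Pydantic and covers many more cases. The cell below is only a simplified sketch of the idea: `sketch_input_schema`, its type mapping, and `get_forecast_stub` (a stub with the same signature as the server's `get_forecast`) are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e2f7a8c4b614"
      },
      "outputs": [],
      "source": [
        "import inspect\n",
        "from typing import Any, Callable\n",
        "\n",
        "# Hypothetical, simplified mapping from Python annotations to JSON Schema types.\n",
        "# Unknown annotations fall back to \"string\" in this sketch only.\n",
        "JSON_TYPES = {str: \"string\", int: \"integer\", float: \"number\", bool: \"boolean\"}\n",
        "\n",
        "\n",
        "def sketch_input_schema(fn: Callable[..., Any]) -> dict[str, Any]:\n",
        "    \"\"\"Build a minimal JSON Schema object describing fn's parameters.\"\"\"\n",
        "    properties: dict[str, Any] = {}\n",
        "    required: list[str] = []\n",
        "    for name, param in inspect.signature(fn).parameters.items():\n",
        "        properties[name] = {\"type\": JSON_TYPES.get(param.annotation, \"string\")}\n",
        "        if param.default is inspect.Parameter.empty:\n",
        "            required.append(name)  # no default value, so the caller must supply it\n",
        "    return {\"type\": \"object\", \"properties\": properties, \"required\": required}\n",
        "\n",
        "\n",
        "async def get_forecast_stub(latitude: float, longitude: float) -> str:\n",
        "    \"\"\"Get the weather forecast for a location (stub with get_forecast's signature).\"\"\"\n",
        "    return \"\"\n",
        "\n",
        "\n",
        "print(get_forecast_stub.__name__, \"->\", inspect.getdoc(get_forecast_stub))\n",
        "print(sketch_input_schema(get_forecast_stub))"
      ]
    },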
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4ebb6082d936"
      },
      "source": [
        "### Gemini agent loop\n",
        "\n",
        "Within an MCP client session, this agent loop runs a multi-turn conversation with a Gemini model and handles tool calls through the MCP server.\n",
        "\n",
        "The function orchestrates the interaction between a user prompt, a Gemini model capable of function calling, and a session object that provides and executes tools. It handles the following cycle:\n",
        "- Get the tool definitions from the MCP client session and declare them to Gemini as functions.\n",
        "- Send the user prompt (and the conversation history) to the model.\n",
        "- If the model requests tool calls (with arguments structured according to each tool's schema), execute them through the MCP session.\n",
        "- Send the tool execution results back to the model.\n",
        "- Repeat until the model returns a text response or the maximum number of tool-execution turns is reached.\n",
        "- Return Gemini's final response, generated from the tool results and the original query.\n",
        "\n",
        "MCP integration with Gemini:\n",
        "\n",
        "<img src=\"https://storage.googleapis.com/github-repo/generative-ai/gemini/mcp/mcp_tool_call.png\" alt=\"MCP with Gemini\" height=\"700\">"
      ]
    },
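    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4d6c8a0e1f25"
      },
      "source": [
        "Before the full implementation, the cell below sketches the same control flow offline. `FakeToolCall`, `FakeModel`, `FakeSession`, and `run_loop_sketch` are hypothetical stand-ins, and the sketch is synchronous for simplicity, while the real loop makes async Gemini and MCP calls. The fake model requests a tool until it has enough results, each tool result is wrapped as a `{\"result\": ...}` or `{\"error\": ...}` payload (as `_execute_tool_calls` does below), and the loop stops on a text answer or after `max_tool_turns` tool rounds."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4d6c8a0e1f26"
      },
      "outputs": [],
      "source": [
        "from dataclasses import dataclass, field\n",
        "from typing import Any, Callable, Optional\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class FakeToolCall:\n",
        "    name: str\n",
        "    args: dict[str, Any]\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class FakeModel:\n",
        "    \"\"\"Hypothetical model: requests a tool until it has enough results, then answers.\"\"\"\n",
        "\n",
        "    tool_rounds_needed: int = 1\n",
        "    calls: int = 0\n",
        "\n",
        "    def generate(self, history: list[Any]) -> tuple[Optional[str], list[FakeToolCall]]:\n",
        "        self.calls += 1\n",
        "        results = [item for item in history if isinstance(item, dict)]\n",
        "        if len(results) < self.tool_rounds_needed:\n",
        "            return None, [FakeToolCall(\"get_alerts\", {\"state\": \"CA\"})]\n",
        "        return f\"Answer based on {results[-1]}\", []\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class FakeSession:\n",
        "    \"\"\"Hypothetical MCP session that runs tools from a dict of plain functions.\"\"\"\n",
        "\n",
        "    tools: dict[str, Callable[..., str]] = field(default_factory=dict)\n",
        "\n",
        "    def call_tool(self, name: str, args: dict[str, Any]) -> dict[str, Any]:\n",
        "        try:\n",
        "            return {\"result\": self.tools[name](**args)}\n",
        "        except Exception as e:  # report the failure to the model instead of raising\n",
        "            return {\"error\": f\"{type(e).__name__}: {e}\"}\n",
        "\n",
        "\n",
        "def run_loop_sketch(\n",
        "    prompt: str, model: FakeModel, session: FakeSession, max_tool_turns: int = 5\n",
        ") -> Optional[str]:\n",
        "    history: list[Any] = [prompt]\n",
        "    text, tool_calls = model.generate(history)\n",
        "    turns = 0\n",
        "    # Keep executing tools while the model asks for them and the turn budget allows.\n",
        "    while tool_calls and turns < max_tool_turns:\n",
        "        turns += 1\n",
        "        history.extend(tool_calls)\n",
        "        for call in tool_calls:\n",
        "            history.append(session.call_tool(call.name, call.args))\n",
        "        text, tool_calls = model.generate(history)\n",
        "    return text  # None if the model was still requesting tools at the limit\n",
        "\n",
        "\n",
        "demo_session = FakeSession(tools={\"get_alerts\": lambda state: f\"No active alerts for {state}.\"})\n",
        "print(run_loop_sketch(\"Any weather alerts in California?\", FakeModel(), demo_session))"
      ]
    },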
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ae7abde705ca"
      },
      "outputs": [],
      "source": [
        "# --- Configuration ---\n",
        "# Consider using a more recent/recommended model if available and suitable\n",
        "DEFAULT_MAX_TOOL_TURNS = 5  # Maximum consecutive turns for tool execution\n",
        "DEFAULT_INITIAL_TEMPERATURE = (\n",
        "    0.0  # Temperature for the first LLM call (more deterministic)\n",
        ")\n",
        "DEFAULT_TOOL_CALL_TEMPERATURE = (\n",
        "    1.0  # Temperature for LLM calls after tool use (potentially more creative)\n",
        ")\n",
        "\n",
        "\n",
        "# Make tool calls via MCP Server\n",
        "async def _execute_tool_calls(\n",
        "    function_calls: list[types.FunctionCall], session: ClientSession\n",
        ") -> list[types.Part]:\n",
        "    \"\"\"\n",
        "    Executes a list of function calls requested by the Gemini model via the session.\n",
        "\n",
        "    Args:\n",
        "        function_calls: A list of FunctionCall objects from the model's response.\n",
        "        session: The session object capable of executing tools via `call_tool`.\n",
        "\n",
        "    Returns:\n",
        "        A list of Part objects, each containing a FunctionResponse corresponding\n",
        "        to the execution result of a requested tool call.\n",
        "    \"\"\"\n",
        "    tool_response_parts: list[types.Part] = []\n",
        "    print(f\"--- Executing {len(function_calls)} tool call(s) ---\")\n",
        "\n",
        "    for func_call in function_calls:\n",
        "        tool_name = func_call.name\n",
        "        # Ensure args is a dictionary, even if missing or not a dict type\n",
        "        args = func_call.args if isinstance(func_call.args, dict) else {}\n",
        "        print(f\"  Attempting to call session tool: '{tool_name}' with args: {args}\")\n",
        "\n",
        "        tool_result_payload: dict[str, Any]\n",
        "        try:\n",
        "            # Execute the tool using the provided session object\n",
        "            # Assumes session.call_tool returns an object with attributes\n",
        "            # like `isError` (bool) and `content` (list of Part-like objects).\n",
        "            tool_result = await session.call_tool(tool_name, args)\n",
        "            print(f\"  Session tool '{tool_name}' execution finished.\")\n",
        "\n",
        "            # Extract result or error message from the tool result object\n",
        "            result_text = \"\"\n",
        "            # Check structure carefully based on actual `session.call_tool` return type\n",
        "            if (\n",
        "                hasattr(tool_result, \"content\")\n",
        "                and tool_result.content\n",
        "                and hasattr(tool_result.content[0], \"text\")\n",
        "            ):\n",
        "                result_text = tool_result.content[0].text or \"\"\n",
        "\n",
        "            if hasattr(tool_result, \"isError\") and tool_result.isError:\n",
        "                error_message = (\n",
        "                    result_text\n",
        "                    or f\"Tool '{tool_name}' failed without specific error message.\"\n",
        "                )\n",
        "                print(f\"  Tool '{tool_name}' reported an error: {error_message}\")\n",
        "                tool_result_payload = {\"error\": error_message}\n",
        "            else:\n",
        "                print(\n",
        "                    f\"  Tool '{tool_name}' succeeded. Result snippet: {result_text[:150]}...\"\n",
        "                )  # Log snippet\n",
        "                tool_result_payload = {\"result\": result_text}\n",
        "\n",
        "        except Exception as e:\n",
        "            # Catch exceptions during the tool call itself\n",
        "            error_message = f\"Tool execution framework failed: {type(e).__name__}: {e}\"\n",
        "            print(f\"  Error executing tool '{tool_name}': {error_message}\")\n",
        "            tool_result_payload = {\"error\": error_message}\n",
        "\n",
        "        # Create a FunctionResponse Part to send back to the model\n",
        "        tool_response_parts.append(\n",
        "            types.Part.from_function_response(\n",
        "                name=tool_name, response=tool_result_payload\n",
        "            )\n",
        "        )\n",
        "    print(\"--- Finished executing tool call(s) ---\")\n",
        "    return tool_response_parts\n",
        "\n",
        "\n",
        "async def run_agent_loop(\n",
        "    prompt: str,\n",
        "    client: genai.Client,\n",
        "    session: ClientSession,\n",
        "    model_id: str = MODEL_ID,\n",
        "    max_tool_turns: int = DEFAULT_MAX_TOOL_TURNS,\n",
        "    initial_temperature: float = DEFAULT_INITIAL_TEMPERATURE,\n",
        "    tool_call_temperature: float = DEFAULT_TOOL_CALL_TEMPERATURE,\n",
        ") -> types.GenerateContentResponse:\n",
        "    \"\"\"\n",
        "    Runs a multi-turn conversation loop with a Gemini model, handling tool calls.\n",
        "\n",
        "    This function orchestrates the interaction between a user prompt, a Gemini\n",
        "    model capable of function calling, and a session object that provides\n",
        "    and executes tools. It handles the cycle of:\n",
        "    1. Sending the user prompt (and conversation history) to the model.\n",
        "    2. If the model requests tool calls, executing them via the `session`.\n",
        "    3. Sending the tool execution results back to the model.\n",
        "    4. Repeating until the model provides a text response or the maximum\n",
        "       number of tool execution turns is reached.\n",
        "\n",
        "    Args:\n",
        "        prompt: The initial user prompt to start the conversation.\n",
        "        client: An initialized `google.genai` Client (configured for Vertex AI).\n",
        "        session: An active session object responsible for listing available tools\n",
        "                 via `list_tools()` and executing them via `call_tool(tool_name, args)`.\n",
        "                 It's also expected to have an `initialize()` method.\n",
        "        model_id: The identifier of the Gemini model to use (e.g., \"gemini-2.0-flash\").\n",
        "        max_tool_turns: The maximum number of consecutive turns dedicated to tool calls\n",
        "                        before forcing a final response or exiting.\n",
        "        initial_temperature: The temperature setting for the first model call.\n",
        "        tool_call_temperature: The temperature setting for subsequent model calls\n",
        "                               that occur after tool execution.\n",
        "\n",
        "    Returns:\n",
        "        The final response from the Gemini model after the\n",
        "        conversation loop concludes (either with a text response or after\n",
        "        reaching the max tool turns).\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If the session object does not provide any tools.\n",
        "        Exception: Can potentially raise exceptions from the underlying API calls\n",
        "                   or session tool execution if not caught internally by `_execute_tool_calls`.\n",
        "    \"\"\"\n",
        "    print(\n",
        "        f\"Starting agent loop with model '{model_id}' and prompt: '{prompt[:100]}...'\"\n",
        "    )\n",
        "\n",
        "    # Initialize conversation history with the user's prompt\n",
        "    contents: list[types.Content] = [\n",
        "        types.Content(role=\"user\", parts=[types.Part(text=prompt)])\n",
        "    ]\n",
        "\n",
        "    # Ensure the session is ready (if needed)\n",
        "    if hasattr(session, \"initialize\") and callable(session.initialize):\n",
        "        print(\"Initializing session...\")\n",
        "        await session.initialize()\n",
        "    else:\n",
        "        print(\"Session object does not have an initialize() method, proceeding anyway.\")\n",
        "\n",
        "    # --- 1. Discover Tools from Session ---\n",
        "    print(\"Listing tools from session...\")\n",
        "    # Assumes session.list_tools() returns an object with a 'tools' attribute (list)\n",
        "    # Each item in the list should have 'name', 'description', and 'inputSchema' attributes.\n",
        "    session_tool_list = await session.list_tools()\n",
        "\n",
        "    if not session_tool_list or not session_tool_list.tools:\n",
        "        raise ValueError(\"No tools provided by the session. Agent loop cannot proceed.\")\n",
        "\n",
        "    # Convert session tools to the format required by the Gemini API\n",
        "    gemini_tool_config = types.Tool(\n",
        "        function_declarations=[\n",
        "            types.FunctionDeclaration(\n",
        "                name=tool.name,\n",
        "                description=tool.description,\n",
        "                parameters=tool.inputSchema,  # Assumes inputSchema is compatible\n",
        "            )\n",
        "            for tool in session_tool_list.tools\n",
        "        ]\n",
        "    )\n",
        "    print(\n",
        "        f\"Configured Gemini with {len(gemini_tool_config.function_declarations)} tool(s).\"\n",
        "    )\n",
        "\n",
        "    # --- 2. Initial Model Call ---\n",
        "    print(\"Making initial call to Gemini model...\")\n",
        "    current_temperature = initial_temperature\n",
        "    response = await client.aio.models.generate_content(\n",
        "        model=model_id,\n",
        "        contents=contents,  # Initial history: just the user prompt\n",
        "        config=types.GenerateContentConfig(\n",
        "            temperature=current_temperature,\n",
        "            tools=[gemini_tool_config],\n",
        "        ),\n",
        "    )\n",
        "    print(\"Initial response received.\")\n",
        "\n",
        "    # Append the model's first response (potentially including function calls) to history\n",
        "    # Handle a response that contains no candidates\n",
        "    if not response.candidates:\n",
        "        print(\"Warning: Initial model response has no candidates.\")\n",
        "        # Return the empty response so the caller can inspect it\n",
        "        return response\n",
        "    contents.append(response.candidates[0].content)\n",
        "\n",
        "    # --- 3. Tool Calling Loop ---\n",
        "    turn_count = 0\n",
        "    # Check specifically for FunctionCall objects in the latest response part\n",
        "    latest_content = response.candidates[0].content\n",
        "    has_function_calls = any(part.function_call for part in (latest_content.parts or []))\n",
        "\n",
        "    while has_function_calls and turn_count < max_tool_turns:\n",
        "        turn_count += 1\n",
        "        print(f\"\\n--- Tool Turn {turn_count}/{max_tool_turns} ---\")\n",
        "\n",
        "        # --- 3.1 Execute Pending Function Calls ---\n",
        "        function_calls_to_execute = [\n",
        "            part.function_call for part in latest_content.parts if part.function_call\n",
        "        ]\n",
        "        tool_response_parts = await _execute_tool_calls(\n",
        "            function_calls_to_execute, session\n",
        "        )\n",
        "\n",
        "        # --- 3.2 Add Tool Responses to History ---\n",
        "        # Send back the results for *all* function calls from the previous turn\n",
        "        contents.append(\n",
        "            types.Content(role=\"function\", parts=tool_response_parts)\n",
        "        )  # Use \"function\" role\n",
        "        print(f\"Added {len(tool_response_parts)} tool response part(s) to history.\")\n",
        "\n",
        "        # --- 3.3 Make Subsequent Model Call with Tool Responses ---\n",
        "        print(\"Making subsequent API call to Gemini with tool responses...\")\n",
        "        current_temperature = tool_call_temperature  # Use different temp for follow-up\n",
        "        response = await client.aio.models.generate_content(\n",
        "            model=model_id,\n",
        "            contents=contents,  # Send updated history\n",
        "            config=types.GenerateContentConfig(\n",
        "                temperature=current_temperature,\n",
        "                tools=[gemini_tool_config],\n",
        "            ),\n",
        "        )\n",
        "        print(\"Subsequent response received.\")\n",
        "\n",
        "        # --- 3.4 Append latest model response and check for more calls ---\n",
        "        if not response.candidates:\n",
        "            print(\"Warning: Subsequent model response has no candidates.\")\n",
        "            break  # Exit loop if no candidates are returned\n",
        "        latest_content = response.candidates[0].content\n",
        "        contents.append(latest_content)\n",
        "        has_function_calls = any(part.function_call for part in (latest_content.parts or []))\n",
        "        if not has_function_calls:\n",
        "            print(\n",
        "                \"Model requested no further tool calls this turn.\"\n",
        "            )\n",
        "\n",
        "    # --- 4. Loop Termination Check ---\n",
        "    if turn_count >= max_tool_turns and has_function_calls:\n",
        "        print(\n",
        "            f\"Maximum tool turns ({max_tool_turns}) reached. Exiting loop with function calls still pending.\"\n",
        "        )\n",
        "    elif not has_function_calls:\n",
        "        print(\"Tool calling loop finished naturally (model provided text response).\")\n",
        "\n",
        "    # --- 5. Return Final Response ---\n",
        "    print(\"Agent loop finished. Returning final response.\")\n",
        "    return response"
      ]
    },
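    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7c2d9e4f1a6b"
      },
      "source": [
        "The control flow of `run_agent_loop` can be exercised without calling Gemini or starting an MCP server. The sketch below is illustrative only: `FakeTurn`, `bounded_tool_loop`, and `fake_execute` are hypothetical stand-ins (not part of the `google-genai` or `mcp` SDKs) for a model reply, the loop itself, and tool execution through the session, and `get_forecast` is a made-up tool name. It shows the two ways the loop ends: the model stops requesting tools, or `max_tool_turns` is reached while calls are still pending.\n",
        "\n",
        "```python\n",
        "import asyncio\n",
        "from dataclasses import dataclass, field\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class FakeTurn:\n",
        "    \"\"\"Hypothetical stand-in for one model reply: requested tool calls and/or text.\"\"\"\n",
        "\n",
        "    calls: list[str] = field(default_factory=list)\n",
        "    text: str = \"\"\n",
        "\n",
        "\n",
        "async def bounded_tool_loop(replies, execute, max_tool_turns=5):\n",
        "    \"\"\"Mirror the loop structure of run_agent_loop without Gemini or MCP.\n",
        "\n",
        "    `replies` yields FakeTurn objects in place of model calls; `execute` is an\n",
        "    async callable in place of tool execution through the MCP session.\n",
        "    Returns (final_turn, tool_turns_used, tool_results).\n",
        "    \"\"\"\n",
        "    results = []\n",
        "    turn = next(replies)  # Initial model call\n",
        "    turn_count = 0\n",
        "    while turn.calls and turn_count < max_tool_turns:\n",
        "        turn_count += 1\n",
        "        for name in turn.calls:\n",
        "            results.append(await execute(name))\n",
        "        turn = next(replies)  # Follow-up call with tool results in the history\n",
        "    return turn, turn_count, results\n",
        "\n",
        "\n",
        "async def fake_execute(name):\n",
        "    return f\"result of {name}\"\n",
        "\n",
        "\n",
        "replies = iter([FakeTurn(calls=[\"get_forecast\"]), FakeTurn(text=\"Sunny in LA.\")])\n",
        "final, turns, results = asyncio.run(bounded_tool_loop(replies, fake_execute))\n",
        "print(final.text, turns, results)  # Sunny in LA. 1 ['result of get_forecast']\n",
        "```\n",
        "\n",
        "If every reply requests a tool and `max_tool_turns=2`, the function returns after two tool turns with `final.calls` still non-empty, which corresponds to the \"Maximum tool turns\" branch of `run_agent_loop`. In a notebook, where an event loop is already running, replace `asyncio.run(...)` with `await bounded_tool_loop(...)`."
      ]
    },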
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "41a2e0fc0dfa"
      },
      "source": [
        "## 1. Use your own MCP server\n",
        "### Start an MCP client session with a custom MCP server and the Gemini agent loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "b556fbb2be66"
      },
      "outputs": [],
      "source": [
        "# Create server parameters for stdio connection\n",
        "weather_server_params = StdioServerParameters(\n",
        "    command=\"python\",\n",
        "    # Update this to the absolute path of your weather_server.py file\n",
        "    args=[\"./server/weather_server.py\"],\n",
        ")"
      ]
    },
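    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3e8b5a0c9d17"
      },
      "source": [
        "The cell above launches the server with whatever `python` is on `PATH` and a relative script path. A minimal sketch, assuming you want the server to run under the same interpreter as the notebook (the one where `mcp` is installed) and with an absolute script path; `stdio_command_for` is a hypothetical helper, not part of the `mcp` SDK. `Path.resolve()` interprets a relative path against the notebook's current working directory.\n",
        "\n",
        "```python\n",
        "import sys\n",
        "from pathlib import Path\n",
        "\n",
        "\n",
        "def stdio_command_for(script: str) -> tuple[str, list[str]]:\n",
        "    \"\"\"Return (command, args) that run `script` with the current interpreter.\"\"\"\n",
        "    # sys.executable is normally the path of the Python running this code.\n",
        "    # resolve() makes the script path absolute, relative to the current\n",
        "    # working directory; the file does not need to exist yet.\n",
        "    script_path = Path(script).expanduser().resolve()\n",
        "    return sys.executable, [str(script_path)]\n",
        "\n",
        "\n",
        "command, args = stdio_command_for(\"./server/weather_server.py\")\n",
        "print(command, args)\n",
        "# The result can then be passed on, for example:\n",
        "# weather_server_params = StdioServerParameters(command=command, args=args)\n",
        "```"
      ]
    },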
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "bd8512e5a8c2"
      },
      "outputs": [],
      "source": [
        "async def run():\n",
        "    async with stdio_client(weather_server_params) as (read, write):\n",
        "        async with ClientSession(\n",
        "            read,\n",
        "            write,\n",
        "        ) as session:\n",
        "            # Test prompt\n",
        "            prompt = \"Tell me about the weather in LA, CA\"\n",
        "            print(f\"Running agent loop with prompt: {prompt}\")\n",
        "            # Run agent loop\n",
        "            res = await run_agent_loop(prompt, client, session)\n",
        "            return res\n",
        "\n",
        "\n",
        "res = await run()\n",
        "print(res.text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f07ab426ca0c"
      },
      "source": [
        "## 2. Use a pre-built MCP server\n",
        "\n",
        "There are [pre-built MCP servers](https://github.com/modelcontextprotocol/servers?tab=readme-ov-file) available for use.\n",
        "\n",
        "Here we use the [BigQuery MCP server](https://github.com/LucasHild/mcp-server-bigquery) as an example.\n",
        "\n",
        "It provides three tools:\n",
        "\n",
        "- `execute-query`: Executes a SQL query using BigQuery\n",
        "- `list-tables`: Lists all tables in the BigQuery database\n",
        "- `describe-table`: Describes the schema of a specific table\n"
      ]
    },
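    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c4f1a2b7e905"
      },
      "source": [
        "`run_agent_loop` passes each tool's `inputSchema` directly to `types.FunctionDeclaration(parameters=...)`. If a server's schemas contain JSON Schema keywords that the Gemini function-declaration schema does not accept, building the declaration or sending the request may fail. A minimal sketch, assuming `$schema` and `additionalProperties` are the offending keys (adjust `UNSUPPORTED_SCHEMA_KEYS`, a hypothetical name, to whatever errors you actually see). The `input_schema` below is made up for illustration and is not copied from `mcp-server-bigquery`.\n",
        "\n",
        "```python\n",
        "from typing import Any\n",
        "\n",
        "# Assumption: keywords to drop before building Gemini function declarations.\n",
        "UNSUPPORTED_SCHEMA_KEYS = {\"$schema\", \"additionalProperties\"}\n",
        "\n",
        "\n",
        "def clean_schema(schema: Any) -> Any:\n",
        "    \"\"\"Recursively remove UNSUPPORTED_SCHEMA_KEYS from a JSON Schema value.\n",
        "\n",
        "    Note: this also drops a property that happens to use one of these names.\n",
        "    \"\"\"\n",
        "    if isinstance(schema, dict):\n",
        "        return {\n",
        "            key: clean_schema(value)\n",
        "            for key, value in schema.items()\n",
        "            if key not in UNSUPPORTED_SCHEMA_KEYS\n",
        "        }\n",
        "    if isinstance(schema, list):\n",
        "        return [clean_schema(item) for item in schema]\n",
        "    return schema\n",
        "\n",
        "\n",
        "# Hypothetical inputSchema in the shape an MCP server might return.\n",
        "input_schema = {\n",
        "    \"$schema\": \"http://json-schema.org/draft-07/schema#\",\n",
        "    \"type\": \"object\",\n",
        "    \"properties\": {\"table_name\": {\"type\": \"string\"}},\n",
        "    \"required\": [\"table_name\"],\n",
        "    \"additionalProperties\": False,\n",
        "}\n",
        "print(clean_schema(input_schema))\n",
        "# {'type': 'object', 'properties': {'table_name': {'type': 'string'}}, 'required': ['table_name']}\n",
        "```\n",
        "\n",
        "Inside `run_agent_loop`, this would be applied as `parameters=clean_schema(tool.inputSchema)`. The function returns new objects, so the session's tool metadata is left unchanged."
      ]
    },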
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "e3f1d09b69ac"
      },
      "outputs": [],
      "source": [
        "# Create server parameters for stdio connection\n",
        "bq_server_params = StdioServerParameters(\n",
        "    command=\"uvx\",  # Executable\n",
        "    args=[\"mcp-server-bigquery\", \"--project\", PROJECT_ID, \"--location\", LOCATION],\n",
        "    env=None,  # Optional environment variables\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "46fd766cd2f3"
      },
      "outputs": [],
      "source": [
        "async def run():\n",
        "    async with stdio_client(bq_server_params) as (read, write):\n",
        "        async with ClientSession(\n",
        "            read,\n",
        "            write,\n",
        "        ) as session:\n",
        "            # Test prompt\n",
        "            prompt = \"Please list my BigQuery tables\"\n",
        "            print(f\"Running agent loop with prompt: {prompt}\")\n",
        "            # Run agent loop\n",
        "            res = await run_agent_loop(prompt, client, session)\n",
        "            return res\n",
        "\n",
        "\n",
        "res = await run()\n",
        "print(res.text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3b513a8c9470"
      },
      "source": [
        "References:\n",
        "- https://modelcontextprotocol.io/introduction\n",
        "- https://github.com/philschmid/gemini-samples/blob/main/examples/gemini-mcp-example.ipynb\n",
        "- https://github.com/modelcontextprotocol/python-sdk"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "intro_to_mcp.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
